#include <seminix/syscall.h>
#include <seminix/param.h>
#include <seminix/slab.h>
#include <seminix/init.h>
#include <cap/cap.h>
#include <cap/endpoint.h>
#include <cap/cnode.h>
#include <cap/ipc_buffer.h>

/* Slab cache backing the capability objects (struct cap_endpoint). */
static struct kmem_cache *cap_endpoint_cachep;
/* Slab cache backing the endpoint objects those capabilities point at. */
static struct kmem_cache *endpoint_cachep;

/*
 * Create the two slab caches used by this file. Both are requested with
 * SLAB_PANIC (presumably panicking on failure rather than returning NULL),
 * and the function itself always reports success.
 */
static __init int endpoint_cap_init(void)
{
    cap_endpoint_cachep = KMEM_CACHE(cap_endpoint, SLAB_PANIC);
    endpoint_cachep = KMEM_CACHE(endpoint, SLAB_PANIC);
    return 0;
}
userver_initcall(endpoint_cap_init)

/*
 * Allocate a zero-filled cap_endpoint from cap_endpoint_cachep.
 *
 * Returns the new object, or ERR_PTR(-SERRNO_ENOMEM) if allocation fails.
 */
static struct cap_endpoint *__endpoint_cap_create(void)
{
    struct cap_endpoint *cap_ep =
        kmem_cache_alloc(cap_endpoint_cachep, GFP_KERNEL | GFP_ZERO);

    if (cap_ep)
        return cap_ep;
    return ERR_PTR(-SERRNO_ENOMEM);
}

/*
 * cap_ops->cap_create hook: allocate a fresh endpoint object plus the
 * capability that references it.
 *
 * @object: creation parameters; only object->endpoint.badge is read here.
 *
 * Returns the new capability, or an ERR_PTR (-SERRNO_ENOMEM) on failure.
 * NOTE(review): ->tsk is not assigned here, although sendipc/recvipc assert
 * it equals current; confirm it is set elsewhere.
 */
static cap_t *endpoint_cap_create(seminix_object_t *object)
{
    struct cap_endpoint *new_endpoint;
    struct endpoint *endpoint;

    new_endpoint = __endpoint_cap_create();
    if (IS_ERR(new_endpoint))
        return (cap_t *)new_endpoint;

    endpoint = kmem_cache_alloc(endpoint_cachep, GFP_KERNEL | GFP_ZERO);
    if (!endpoint)
        goto free_cap;

    INIT_LIST_HEAD(&endpoint->recv_list);
    INIT_LIST_HEAD(&endpoint->send_list);
    INIT_LIST_HEAD(&endpoint->reply_list);
    spin_lock_init(&endpoint->lock);

    new_endpoint->badge = object->endpoint.badge;
    new_endpoint->status = IPC_IDLE;
    init_waitqueue_head(&new_endpoint->wait);
    new_endpoint->endpoint = endpoint;

    return CAP_REF(new_endpoint);

free_cap:
    /* new_endpoint came from cap_endpoint_cachep, so it must go back there. */
    kmem_cache_free(cap_endpoint_cachep, new_endpoint);
    return ERR_PTR(-SERRNO_ENOMEM);
}

/*
 * cap_ops->cap_delete hook: free a childless endpoint capability together
 * with the endpoint object it points at.
 */
static void endpoint_cap_delete(cap_t *cap)
{
    cap_endpoint_t *ep_cap;

    assert(list_empty(&cap->child));
    assert(cap_get_cap_type(cap) == cap_endpoint_cap);

    ep_cap = (cap_endpoint_t *)cap;
    kmem_cache_free(endpoint_cachep, ep_cap->endpoint);
    kmem_cache_free(cap_endpoint_cachep, ep_cap);
}

/*
 * cap_ops->cap_dup hook: create a new capability that shares @cap's endpoint
 * object. The copy starts with badge 0 in the IPC_IDLE state.
 *
 * Returns the new capability, or an ERR_PTR on allocation failure.
 */
static cap_t *endpoint_cap_dup(cap_t *cap)
{
    struct cap_endpoint *dup;

    BUG_ON(cap_get_cap_type(cap) != cap_endpoint_cap);

    dup = __endpoint_cap_create();
    if (!IS_ERR(dup)) {
        dup->endpoint = CAP_ENDPOINT_PTR(cap)->endpoint;
        dup->status = IPC_IDLE;
        dup->badge = 0;
        init_waitqueue_head(&dup->wait);
        return CAP_REF(dup);
    }
    return (cap_t *)dup;
}

/*
 * cap_ops->cap_revoke hook: free only the childless capability object.
 * Unlike endpoint_cap_delete, the shared endpoint object is left alone.
 */
static void endpoint_cap_revoke(cap_t *cap)
{
    cap_endpoint_t *ep_cap = (cap_endpoint_t *)cap;

    BUG_ON(!list_empty(&cap->child));
    BUG_ON(cap_get_cap_type(cap) != cap_endpoint_cap);

    kmem_cache_free(cap_endpoint_cachep, ep_cap);
}

/* Capability lifecycle operations for endpoint capabilities. */
const struct cap_ops endpoint_cap_ops __ro_after_init = {
    .cap_create = endpoint_cap_create,
    .cap_dup    = endpoint_cap_dup,
    .cap_revoke = endpoint_cap_revoke,
    .cap_delete = endpoint_cap_delete,
};

/*
 * Check whether @badge may be assigned to endpoint capability @cap.
 *
 * Returns -SERRNO_EAGAIN if @cap already carries a non-zero badge, and
 * -SERRNO_EEXIST if an endpoint ancestor already uses @badge. Returns 0
 * once the walk reaches the first non-endpoint ancestor. Running out of
 * parents while still on endpoint capabilities is treated as a bug.
 *
 * NOTE(review): only the ancestor chain is compared; siblings sharing a
 * parent are not, so two children may end up with the same badge.
 */
static int endpoint_check_badge(cap_t *cap, unsigned long badge)
{
    cap_t *ancestor;

    if (CAP_ENDPOINT_PTR(cap)->badge)
        return -SERRNO_EAGAIN;

    for (ancestor = cap->parent; ancestor; ancestor = ancestor->parent) {
        if (cap_get_cap_type(ancestor) != cap_endpoint_cap)
            return 0;
        if (CAP_ENDPOINT_PTR(ancestor)->badge == badge)
            return -SERRNO_EEXIST;
    }
    BUG();
}

/*
 * epset_badge(ep, parent, badge): assign @badge to endpoint capability @ep.
 * @parent must be an endpoint capability that cap_is_parent() accepts as a
 * parent of @ep. The badge must pass endpoint_check_badge().
 * NOTE(review): the check and the store are not done under a lock.
 */
SYSCALL_DEFINE3(epset_badge, int, ep, int, parent, unsigned long, badge)
{
    cap_t *parent_cap, *cap;
    int err;

    parent_cap = cnode_capget(parent, cap_endpoint_cap);
    if (IS_ERR(parent_cap))
        return PTR_ERR(parent_cap);

    cap = cnode_capget(ep, cap_endpoint_cap);
    if (IS_ERR(cap)) {
        err = PTR_ERR(cap);
    } else {
        if (!cap_is_parent(parent_cap, cap)) {
            err = -SERRNO_EILLEGAL;
        } else {
            err = endpoint_check_badge(cap, badge);
            if (!err)
                CAP_ENDPOINT_PTR(cap)->badge = badge;
        }
        cnode_capput(cap);
    }

    cnode_capput(parent_cap);
    return err;
}

/*
 * Wait for a peer to complete the send operation @op that @cap_endpoint has
 * queued on endpoint->send_list or endpoint->reply_list.
 *
 * Must be called WITHOUT endpoint->lock held. The peer that completes the
 * operation (a receiver or a replier) needs that lock.
 *
 * On timeout the request is withdrawn from whatever list it is still on.
 * This stops a later receiver from dequeuing a stale entry, and lets the list
 * node be reused safely by the next IPC on this capability. If a peer
 * changed ->status just as the timeout fired, the operation counts as
 * completed.
 * NOTE(review): the re-check assumes peers update ->status while still
 * holding endpoint->lock when they dequeue this capability.
 *
 * Returns 0 on completion, -SERRNO_EINTR if the wait ended in IPC_INTR,
 * or -SERRNO_ETIMEOUT.
 */
static int endpoint_sendipc_wait(cap_endpoint_t *cap_endpoint, int op, ktime_t timeout)
{
    endpoint_t *endpoint = cap_endpoint->endpoint;

    if (!wait_event_timeout(cap_endpoint->wait, cap_endpoint->status != op, timeout)) {
        spin_lock(&endpoint->lock);
        if (cap_endpoint->status == op) {
            list_del(&cap_endpoint->list);
            cap_endpoint->status = IPC_IDLE;
            spin_unlock(&endpoint->lock);
            return -SERRNO_ETIMEOUT;
        }
        spin_unlock(&endpoint->lock);
    }

    if (cap_endpoint->status == IPC_INTR)
        return -SERRNO_EINTR;
    return 0;
}

/*
 * Send fast path: a receiver is already waiting on endpoint->recv_list.
 *
 * Called with endpoint->lock held; the lock is RELEASED before returning.
 * This keeps the reply wait from sleeping while holding a spinlock.
 *
 * The first receiver is dequeued and gets the message header, badge, reply
 * flag and message registers (or a notification for IPC_SEND_NOTIFY). It is
 * then marked IPC_IDLE and woken. For IPC_SEND_REPLY / IPC_NBSEND_REPLY the
 * sender parks itself on endpoint->reply_list and waits for the reply.
 *
 * Returns 0 or a negative error from endpoint_sendipc_wait().
 */
static int endpoint_sendipc_fastpath(cap_endpoint_t *cap_endpoint, int op, seminix_message_t send_mess, ktime_t timeout)
{
    bool is_reply = (op == IPC_SEND_REPLY || op == IPC_NBSEND_REPLY);
    cap_endpoint_t *recv_cap_endpoint;
    endpoint_t *endpoint = cap_endpoint->endpoint;
    struct tcb *tsk;
    int args;

    recv_cap_endpoint = list_first_entry(&endpoint->recv_list, cap_endpoint_t, list);
    assert(recv_cap_endpoint->status == IPC_RECV);
    list_del(&recv_cap_endpoint->list);
    tsk = recv_cap_endpoint->tsk;

    if (op == IPC_SEND_NOTIFY) {
        /*
         * NOTE(review): the woken receiver's recvipc returns 0 here, not
         * SEMINIX_NOTIFY_NORMAL, and its badge/reply fields are not refreshed.
         */
        set_task_notify(tsk, 0, cap_endpoint->badge, send_mess);
        set_task_message(tsk, seminix_message_new(0, 0, 1));
    } else {
        set_task_message(tsk, send_mess);
        args = seminix_message_get_length(send_mess);
        for (int i = 0; i < args; i++)
            set_task_mr(tsk, i, get_current_mr(i));
        set_task_badge(tsk, cap_endpoint->badge);
        /* Always set explicitly so a previous call's flag cannot leak through. */
        set_task_reply(tsk, is_reply);
        if (is_reply) {
            cap_endpoint->status = op;
            list_add(&cap_endpoint->list, &endpoint->reply_list);
        }
    }

    /* Every delivery, not only the reply case, must release the receiver. */
    recv_cap_endpoint->status = IPC_IDLE;
    wake_up(&recv_cap_endpoint->wait);
    spin_unlock(&endpoint->lock);

    if (!is_reply)
        return 0;
    return endpoint_sendipc_wait(cap_endpoint, op, timeout);
}

/*
 * Send @send_mess on @cap_endpoint's endpoint using operation @op.
 *
 * If a receiver is waiting, the fast path hands the message over directly.
 * Otherwise:
 *   - IPC_NBSEND / IPC_NBSEND_REPLY fail with -SERRNO_ENORECV;
 *   - IPC_SEND_NOTIFY is queued in endpoint->notify[] (-SERRNO_EOVERFLOW
 *     when IPC_NOTIFY_MAX entries are already queued);
 *   - any other op queues the sender on endpoint->send_list and sleeps
 *     until a receiver (and, for reply ops, a replier) completes it.
 *
 * Returns 0 on success or a negative error.
 */
static int endpoint_sendipc(cap_endpoint_t *cap_endpoint, int op, seminix_message_t send_mess, ktime_t timeout)
{
    endpoint_t *endpoint = cap_endpoint->endpoint;
    bool is_reply = (op == IPC_SEND_REPLY || op == IPC_NBSEND_REPLY);

    assert(cap_endpoint->tsk == current);
    spin_lock(&endpoint->lock);

    /* The fast path drops endpoint->lock itself. */
    if (!list_empty(&endpoint->recv_list))
        return endpoint_sendipc_fastpath(cap_endpoint, op, send_mess, timeout);

    if (op == IPC_NBSEND || op == IPC_NBSEND_REPLY) {
        spin_unlock(&endpoint->lock);
        return -SERRNO_ENORECV;
    }

    if (op == IPC_SEND_NOTIFY) {
        if (unlikely(endpoint->notify_cur == IPC_NOTIFY_MAX)) {
            WARN(1, "badge %pa:send endpoint notify overflow, discard the message\n", &cap_endpoint->badge);
            spin_unlock(&endpoint->lock);
            return -SERRNO_EOVERFLOW;
        }
        endpoint->notify[endpoint->notify_cur].badge = cap_endpoint->badge;
        endpoint->notify[endpoint->notify_cur].message = send_mess;
        endpoint->notify_cur++;
        spin_unlock(&endpoint->lock);
        return 0;
    }

    /*
     * The receive fast path reads the header, badge and reply flag from the
     * queued sender's own TCB, so publish them before becoming visible.
     */
    set_task_message(cap_endpoint->tsk, send_mess);
    set_task_badge(cap_endpoint->tsk, cap_endpoint->badge);
    set_task_reply(cap_endpoint->tsk, is_reply);

    list_add_tail(&cap_endpoint->list, &endpoint->send_list);
    cap_endpoint->status = op;
    spin_unlock(&endpoint->lock);

    return endpoint_sendipc_wait(cap_endpoint, op, timeout);
}

/*
 * Read a relative timeout from user space and turn it into an absolute
 * deadline: ktime_get() plus the user value, computed with ktime_add_safe().
 *
 * Returns 0 and stores the deadline in *@t, -SERRNO_EFAULT if the copy from
 * user space fails, or -SERRNO_EINVAL if the timespec is invalid.
 */
static int utime_get(struct __kernel_timespec __user *utime, ktime_t *t)
{
    struct timespec64 ts;

    if (get_timespec64(&ts, utime))
        return -SERRNO_EFAULT;
    if (!timespec64_valid(&ts))
        return -SERRNO_EINVAL;

    *t = ktime_add_safe(ktime_get(), timespec64_to_ktime(ts));
    return 0;
}

/*
 * sendipc(ep, send_mess, op, recv_mess, utime): send on endpoint cap @ep.
 *
 * @op must be one of IPC_SEND, IPC_NBSEND, IPC_SEND_REPLY, IPC_NBSEND_REPLY
 * or IPC_SEND_NOTIFY. @utime (optional) is a relative timeout and is honoured
 * only by the ops that can block waiting for a peer. For reply ops, the reply
 * message header is copied to *@recv_mess on success.
 *
 * Returns 0 on success or a negative error.
 */
SYSCALL_DEFINE5(sendipc, int, ep, seminix_message_t, send_mess, int, op,
    seminix_message_t __user *, recv_mess, struct __kernel_timespec __user *, utime)
{
    int ret;
    cap_t *cap;
    ktime_t t = KTIME_MAX;
    seminix_message_t mess;

    /* @op comes from user space; unknown values must not reach the queues. */
    if (op != IPC_SEND && op != IPC_NBSEND && op != IPC_SEND_REPLY &&
        op != IPC_NBSEND_REPLY && op != IPC_SEND_NOTIFY)
        return -SERRNO_EINVAL;

    cap = cnode_capget(ep, cap_endpoint_cap);
    if (IS_ERR(cap))
        return PTR_ERR(cap);

    if (CAP_ENDPOINT_PTR(cap)->badge == 0) {
        ret = -SERRNO_EINACTIVE;
        goto out;
    }

    if (!cap_can_send(cap)) {
        ret = -SERRNO_EILLEGAL;
        goto out;
    }

    if (utime && (op == IPC_SEND || op == IPC_SEND_REPLY ||
              op == IPC_NBSEND_REPLY)) {
        ret = utime_get(utime, &t);
        if (ret)
            goto out;
    }

    ret = endpoint_sendipc(CAP_ENDPOINT_PTR(cap), op, send_mess, t);
    if (ret < 0)
        goto out;
    /* Success is always reported as 0; never leak a positive wait result. */
    ret = 0;

    if (op == IPC_SEND_REPLY || op == IPC_NBSEND_REPLY) {
        mess = get_current_message();
        if (put_user(mess, recv_mess))
            ret = -SERRNO_EFAULT;
    }
out:
    cnode_capput(cap);
    return ret;
}

/*
 * Receive fast path: at least one sender is queued on endpoint->send_list.
 *
 * Called with endpoint->lock held; the caller releases it. Dequeues the
 * first sender and copies its message header, reply flag, badge and message
 * registers into the current task. A sender that asked for a reply
 * (IPC_SEND_REPLY / IPC_NBSEND_REPLY) moves to endpoint->reply_list without
 * a status change, so it keeps waiting for the reply. Any other sender is
 * marked IPC_IDLE and woken.
 *
 * @op and @timeout are unused. Always returns 0.
 */
static int endpoint_recvipc_fastpath(cap_endpoint_t *cap_endpoint, int op,
    seminix_message_t *recv_mess, bool *reply, unsigned long *badge, ktime_t timeout)
{
    int args;
    cap_endpoint_t *send_cap_endpoint;
    endpoint_t *endpoint = cap_endpoint->endpoint;
    struct tcb *tsk;

    send_cap_endpoint = list_first_entry(&endpoint->send_list, cap_endpoint_t, list);
    list_del(&send_cap_endpoint->list);
    tsk = send_cap_endpoint->tsk;
    *recv_mess = get_task_message(tsk);
    *reply = get_task_reply(tsk);
    *badge = get_task_badge(tsk);
    args = seminix_message_get_length(*recv_mess);
    for (int i = 0; i < args; i++)
        set_current_mr(i, get_task_mr(tsk, i));
    if (send_cap_endpoint->status == IPC_SEND_REPLY ||
        send_cap_endpoint->status == IPC_NBSEND_REPLY) {
        list_add_tail(&send_cap_endpoint->list, &endpoint->reply_list);
    } else {
        send_cap_endpoint->status = IPC_IDLE;
        wake_up(&send_cap_endpoint->wait);
    }
    return 0;
}

/*
 * Receive on @cap_endpoint's endpoint using operation @op.
 *
 * Pending deliveries are checked in this order:
 *   1. queued IRQ notifications -> copied out, returns SEMINIX_NOTIFY_IRQ;
 *   2. queued notifications     -> copied out, returns SEMINIX_NOTIFY_NORMAL;
 *   3. a queued sender          -> fast path, returns 0.
 * With nothing pending, IPC_NBRECV fails with -SERRNO_ENOSEND. Otherwise the
 * caller queues on endpoint->recv_list and sleeps until a sender delivers
 * (returns 0), the wait ends in IPC_INTR (-SERRNO_EINTR) or @timeout expires
 * (-SERRNO_ETIMEOUT).
 */
static int endpoint_recvipc(cap_endpoint_t *cap_endpoint, int op,
    seminix_message_t *recv_mess, bool *reply, unsigned long *badge, ktime_t timeout)
{
    int ret;
    endpoint_t *endpoint = cap_endpoint->endpoint;

    assert(cap_endpoint->tsk == current);
    spin_lock(&endpoint->lock);
    if (endpoint->irqnotify_cur != 0) {
        for (int i = 0; i < endpoint->irqnotify_cur; i++)
            set_current_irqnotify(i, endpoint->irqnotify[i]);
        *recv_mess = seminix_message_new(0, 0, endpoint->irqnotify_cur);
        *reply = false;
        *badge = 0;
        endpoint->irqnotify_cur = 0;
        spin_unlock(&endpoint->lock);
        return SEMINIX_NOTIFY_IRQ;
    }
    if (endpoint->notify_cur != 0) {
        for (int i = 0; i < endpoint->notify_cur; i++)
            set_current_notify(i, endpoint->notify[i].badge, endpoint->notify[i].message);
        *recv_mess = seminix_message_new(0, 0, endpoint->notify_cur);
        *reply = false;
        *badge = 0;
        endpoint->notify_cur = 0;
        spin_unlock(&endpoint->lock);
        return SEMINIX_NOTIFY_NORMAL;
    }
    if (list_empty(&endpoint->send_list)) {
        if (op == IPC_NBRECV) {
            spin_unlock(&endpoint->lock);
            return -SERRNO_ENOSEND;
        }
        list_add_tail(&cap_endpoint->list, &endpoint->recv_list);
        cap_endpoint->status = op;
        spin_unlock(&endpoint->lock);
        ret = wait_event_timeout(cap_endpoint->wait, cap_endpoint->status != op, timeout);
        if (!ret) {
            /*
             * Timed out: if no sender claimed us meanwhile, leave recv_list so
             * a later sender cannot hand a message to a receiver that is gone
             * (and the node is not re-added while still linked).
             */
            spin_lock(&endpoint->lock);
            if (cap_endpoint->status == op) {
                list_del(&cap_endpoint->list);
                cap_endpoint->status = IPC_IDLE;
                spin_unlock(&endpoint->lock);
                return -SERRNO_ETIMEOUT;
            }
            spin_unlock(&endpoint->lock);
            /* A sender completed the hand-off just as the timeout fired. */
        }
        if (cap_endpoint->status == IPC_INTR)
            return -SERRNO_EINTR;
        *recv_mess = get_current_message();
        *badge = get_current_badge();
        *reply = get_current_reply();
        return 0;
    }
    ret = endpoint_recvipc_fastpath(cap_endpoint, op, recv_mess, reply, badge, timeout);
    spin_unlock(&endpoint->lock);
    return ret;
}

/*
 * recvipc(ep, op, recv_mess, rep, sender, utime): receive on endpoint cap @ep.
 *
 * @op must be IPC_RECV (blocking, honours the optional relative timeout
 * @utime) or IPC_NBRECV (non-blocking). On success the message header, the
 * reply-expected flag and the sender's badge are written to @recv_mess,
 * @rep and @sender.
 *
 * Returns 0 for a message, SEMINIX_NOTIFY_IRQ / SEMINIX_NOTIFY_NORMAL when
 * notifications were delivered, or a negative error.
 */
SYSCALL_DEFINE6(recvipc, int, ep, int, op,
    seminix_message_t __user *, recv_mess, bool __user *, rep,
    unsigned long __user *, sender, struct __kernel_timespec __user *, utime)
{
    int ret;
    seminix_message_t mess;
    bool reply;
    unsigned long badge;
    cap_t *cap;
    ktime_t t = KTIME_MAX;

    /*
     * @op comes from user space; a queued receiver must be in IPC_RECV
     * (the send fast path asserts this).
     */
    if (op != IPC_RECV && op != IPC_NBRECV)
        return -SERRNO_EINVAL;

    cap = cnode_capget(ep, cap_endpoint_cap);
    if (IS_ERR(cap))
        return PTR_ERR(cap);

    if (CAP_ENDPOINT_PTR(cap)->badge == 0) {
        ret = -SERRNO_EINACTIVE;
        goto out;
    }

    if (!cap_can_recv(cap)) {
        ret = -SERRNO_EILLEGAL;
        goto out;
    }

    /* Only a blocking receive can time out. */
    if (utime && op == IPC_RECV) {
        ret = utime_get(utime, &t);
        if (ret)
            goto out;
    }

    ret = endpoint_recvipc(CAP_ENDPOINT_PTR(cap), op, &mess, &reply, &badge, t);
    if (ret < 0)
        goto out;   /* must still drop the capability reference */

    if (put_user(mess, recv_mess) || put_user(reply, rep) ||
        put_user(badge, sender))
        ret = -SERRNO_EFAULT;
out:
    cnode_capput(cap);
    return ret;
}

/*
 * Reply to the sender parked on endpoint->reply_list whose capability badge
 * equals @badge.
 *
 * Delivers the @reply_mess header and the caller's first N message registers
 * (N = message length) to that sender, marks it IPC_IDLE and wakes it. All of
 * this is done under endpoint->lock, so a sender whose wait is timing out
 * concurrently sees a consistent list/status state.
 *
 * Returns 0, or -SERRNO_ENOREPLY if no waiting sender has that badge.
 */
static int endpoint_replyipc(cap_endpoint_t *cap_endpoint, unsigned long badge, seminix_message_t reply_mess)
{
    endpoint_t *endpoint = cap_endpoint->endpoint;
    cap_endpoint_t *reply_endpoint;
    struct tcb *tsk;
    int args;

    spin_lock(&endpoint->lock);
    list_for_each_entry(reply_endpoint, &endpoint->reply_list, list) {
        if (reply_endpoint->badge != badge)
            continue;

        list_del(&reply_endpoint->list);

        /* sendipc copies get_current_message() back to user after waking. */
        tsk = reply_endpoint->tsk;
        set_task_message(tsk, reply_mess);
        args = seminix_message_get_length(reply_mess);
        for (int i = 0; i < args; i++)
            set_task_mr(tsk, i, get_current_mr(i));

        reply_endpoint->status = IPC_IDLE;
        wake_up(&reply_endpoint->wait);
        spin_unlock(&endpoint->lock);
        return 0;
    }
    spin_unlock(&endpoint->lock);
    return -SERRNO_ENOREPLY;
}

/*
 * replyipc(ep, badge, reply_mess): answer the pending request from the sender
 * identified by @badge on endpoint cap @ep. The capability must be badged and
 * have receive rights.
 */
SYSCALL_DEFINE3(replyipc, int, ep, unsigned long, badge,
    seminix_message_t, reply_mess)
{
    cap_t *cap;
    int ret;

    cap = cnode_capget(ep, cap_endpoint_cap);
    if (IS_ERR(cap))
        return PTR_ERR(cap);

    if (!CAP_ENDPOINT_PTR(cap)->badge)
        ret = -SERRNO_EINACTIVE;
    else if (!cap_can_recv(cap))
        ret = -SERRNO_EILLEGAL;
    else
        ret = endpoint_replyipc(CAP_ENDPOINT_PTR(cap), badge, reply_mess);

    cnode_capput(cap);
    return ret;
}
